Refactor Target Platform Capabilities Design #1276
Conversation
- Create a new `schema` package to house all target platform modeling classes
- Introduce a new versioning system with minor and patch versions

Additional Changes:
- Update existing target platform models to adhere to the new versioning convention
- Add necessary metadata
- Correct all import statements
- Update and enhance tests to reflect the design changes
```diff
-from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationConfigOptions, \
-    TargetPlatformCapabilities, LayerFilterParams, OpQuantizationConfig
+from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, LayerFilterParams
+from model_compression_toolkit.target_platform_capabilities.schema.v1 import OpQuantizationConfig, \
```
Given that the schema is going to be updated (maybe frequently), and MCT core would change the schema that it is using accordingly, we need to figure out how to modify these imports throughout the code without accessing ".v1" directly at each import.
Is there a way to "export" (as the TPC package "API") a default schema that references the currently used version, so that all imports point to it and, when we want to change the schema MCT uses, we only have to change it in one place?
Maybe @irenaby would have an idea how it can be done?
Let's discuss this offline if needed.
BTW, this is not mandatory for this PR, but solving it here would be better, because it will save us from editing all these files again in a separate PR.
The only part of MCT that should be aware of the schema is the parser, once we have it. It should parse the schema into whatever representation the rest of MCT works with. If we want to reuse the same classes for now, we can add a proxy module that only imports the classes from the schema, and have the rest of MCT import from that proxy module.
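A rough sketch of the proxy-module pattern being suggested (class and module names here are illustrative stand-ins, not actual MCT identifiers):

```python
# Stand-in for the versioned schema module, e.g. schema/v1.py:
class OpQuantizationConfigV1:
    SCHEMA_VERSION = 1

# Stand-in for the proxy module, e.g. schema/current.py. It only re-binds
# the classes of the version MCT currently uses; a version bump edits this
# single line (in real code: `from schema.v1 import OpQuantizationConfig`).
OpQuantizationConfig = OpQuantizationConfigV1

# The rest of MCT would import from the proxy
# (`from schema.current import OpQuantizationConfig`),
# so no other file mentions ".v1" directly.
cfg = OpQuantizationConfig()
print(cfg.SCHEMA_VERSION)  # -> 1
```

Switching the codebase to v2 would then mean changing only the re-binding inside the proxy module.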
I agree with @irenaby, we should have some proxy module at this stage...
Files with resolved review threads:
- model_compression_toolkit/target_platform_capabilities/target_platform/fusing.py
- ...l_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py
- model_compression_toolkit/target_platform_capabilities/target_platform/operators.py
- model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py
- ...atform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py
```diff
@@ -102,7 +102,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
     signedness=Signedness.AUTO)

 # We define an 8-bit config for linear operations quantization, that include a kernel and bias attributes.
-linear_eight_bits = tp.OpQuantizationConfig(
+linear_eight_bits = model_compression_toolkit.target_platform_capabilities.schema.v1.OpQuantizationConfig(
```
maybe import the path as "schema_v1" (or create an alias at the beginning of the file) instead of repeating it everywhere?
or do we want it to be explicit in the TPC model description? @haihabi
even if we want it to, eventually this file will be a JSON and not written as code using imports, so these fields in the JSON won't include the entire schema path anyway...
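The alias idea amounts to binding the long module path to a short name once at the top of the file. A runnable sketch of the mechanics, with a stdlib module standing in for the MCT path (which may not be installed):

```python
import importlib

# In the TPC model file this would be:
#   import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema_v1
# Here `json` stands in for the versioned schema module so the sketch runs anywhere.
schema_v1 = importlib.import_module("json")

# Every later reference uses the short alias instead of the full dotted path:
payload = schema_v1.dumps({"tpc_minor_version": 1})
print(payload)  # -> {"tpc_minor_version": 1}
```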
But the JSON file should have the schema version from the constant in the schema main class.
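A sketch of what that could look like on export (the class name, constant, and JSON fields are assumptions for illustration, not MCT's actual export format):

```python
import json

# Hypothetical main schema class carrying the version as a class constant.
class TargetPlatformModel:
    SCHEMA_VERSION = 1

# On export, the JSON takes its version from the constant rather than a
# hand-written literal, so the file always matches the schema that wrote it.
exported = json.dumps({
    "schema_version": TargetPlatformModel.SCHEMA_VERSION,
    "operator_sets": [],
})
print(exported)  # -> {"schema_version": 1, "operator_sets": []}
```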
> But the JSON file should have the schema version from the constant in the schema main class.

So you agree that the path here can be shortened with an alias to improve file readability?
@haihabi
Yes
```python
tpc_patch_version = f'{tpc.tp_model.tpc_patch_version}'
tpc_platform_type = f'{tpc.tp_model.tpc_platform_type}'
tpc_schema = f'{tpc.tp_model.SCHEMA_VERSION}'
return {MCT_VERSION: mct_version,
```
If you need to access the fields elsewhere, why not define a class or a named tuple? If this goes directly into the model, there is no need to define global constants.
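The named-tuple suggestion could look like this (the class and the example values are hypothetical; the field names mirror the snippet in the diff):

```python
from typing import NamedTuple

# Hypothetical typed container grouping the version fields
# instead of loose global constants.
class TPCVersionInfo(NamedTuple):
    mct_version: str
    tpc_minor_version: str
    tpc_patch_version: str
    tpc_platform_type: str
    tpc_schema: str

info = TPCVersionInfo(mct_version="2.2.0", tpc_minor_version="1",
                      tpc_patch_version="0", tpc_platform_type="imx500",
                      tpc_schema="1")

# The whole record can be dropped into the model as a dict in one call:
print(info._asdict()["tpc_schema"])  # -> 1
```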
```diff
@@ -5,7 +5,7 @@
 Several training methods may be applied by the user to train the QAT ready model
 created by `keras_quantization_aware_training_init` method in [`keras/quantization_facade`](../quantization_facade.py).
 Each `TrainingMethod` (an enum defined in the [`qat_config`](../../common/qat_config.py))
-and [`QuantizationMethod`](../../../target_platform_capabilities/target_platform/op_quantization_config.py)
+and `QuantizationMethod`
```
Fix the path according to the new file location.